# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Query Detector.
"""
import hashlib
from collections import Counter
from multiprocessing import Pool
import numpy as np
import cv2

from mindarmour.utils.logger import LogUtil
from mindarmour import Detector

# Module-wide side effect: make NumPy print arrays in full instead of
# summarising long ones with an ellipsis.
np.set_printoptions(threshold=np.inf)

# Logger instance obtained from LogUtil; used by every log call in this module.
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument of each LOGGER call below.
TAG = 'QueryDetector'


class QueryDetector(Detector):
    """
    QueryDetector detects highly similar queries generated by iterative optimization using probabilistic fingerprints,
    a compact hash representation computed for each input query.

    Reference: `Blacklight: Scalable Defense for Neural Networks against Query-Based Black-Box Attacks.
    Huiying Li, Shawn Shan, Emily Wenger, Jiayun Zhang, Haitao Zheng, Ben Y. Zhao. 2022.
    In Proceedings of 31st USENIX Security Symposium (USENIX Security'2022).`_

    Args:
        sample_query (numpy.ndarray): an example query. The received queries should have
            the same shape as sample_query.
        window_size (int): the size of sliding window to split each query.
            Default: 50.
        step_size (int): the step between adjacent sliding windows.
            Default: 1.
        roundto (int): the quantization step of each hash.
            Default: 50.
        hash_kept (int): number of hash fingerprints kept to represent each query.
            Default: 50.
        threshold (int): number of matched hashes to decide a query is an attack. A query is
            flagged when its match count is strictly greater than this value.
            Default: 25.
        workers (int): number of processes to compute hash fingerprints for the input queries.
            Default: 5.
        salt (numpy.ndarray): a randomly generated salt image (of the same dimensions as sample_query),
            to improve robustness against adaptive attacks. When None, a random salt with values
            in [0, 255) is generated.
            Default: None.

    Examples:
        >>> from query_detector import QueryDetector
        >>> import numpy as np
        >>> np.random.seed(5)
        >>> benign_queries = np.random.randint(0, 255, [10, 224, 224, 3], np.uint8)
        >>> suspicious_queries = benign_queries[1] + np.random.rand(20, 224, 224, 3)
        >>> suspicious_queries = suspicious_queries.astype(np.uint8)
        >>> detector = QueryDetector(suspicious_queries[0])
        >>> detector.detect(suspicious_queries)
        >>> detector.delete()
    """

    def __init__(self, sample_query, window_size=50, step_size=1, roundto=50, hash_kept=50, threshold=25,
                 workers=5, salt=None):
        super(QueryDetector, self).__init__()
        self.window_size = window_size
        self.step_size = step_size
        self.roundto = roundto
        self.num_hashes_keep = hash_kept
        self.threshold = threshold
        self.workers = workers
        self._detected_queries = []

        if salt is not None:
            self.salt = salt
        else:
            # No salt supplied: draw a random one shaped like the sample query.
            self.salt = np.random.rand(*sample_query.shape) * 255.

        # Maps a hash string to the list of query indices (1-based) that produced it.
        self.hash_dict = {}
        self.output = {}
        # Running count of queries seen so far; doubles as the id of the latest query.
        self.input_idx = 0
        # Worker pool used by hash_img(); release it with delete().
        self.pool = Pool(processes=workers)

    def preprocess(self, array, roundto=1, normalized=False):
        """
        Perform salted pixel quantization, converting continuous pixel values into a finite set of discrete values.
        Quantization increases similarity between (attack) queries.

        Args:
            array (numpy.ndarray): an input query.
            roundto (int): the quantization step of each hash.
            normalized (bool): True for images within [0,1], False for images within [0,255].

        Returns:
            numpy.ndarray, the flattened int16 query after salting and quantization.
        """
        if normalized:
            array = np.array(array) * 255.
        # Add the salt and wrap around so values stay inside the pixel range.
        array = (array + self.salt) % 255.
        array = array.reshape(-1)
        # Snap every value to the nearest multiple of `roundto`.
        array = np.around(array / roundto, decimals=0) * roundto
        array = array.astype(np.int16)
        return array

    def hash_img(self, img, window_size, roundto, step_size, preprocess=True):
        """
        Compute hash fingerprint for the query image.

        Args:
            img (numpy.ndarray): the query image.
            window_size (int): the size of sliding window to split each query.
            roundto (int): the quantization step of each hash.
            step_size (int): the step between adjacent sliding windows.
            preprocess (bool): whether to perform salted pixel quantization.

        Returns:
            list[str], the de-duplicated hashes of the query image (each hex digest reversed),
            sorted in descending order, as its fingerprint.
        """
        if preprocess:
            img = self.preprocess(img, roundto)

        total_len = len(img)
        # Same window count as before: floor((len - window + 1) / step), never negative.
        num_windows = max(0, (total_len - window_size + 1) // step_size)
        # Ship only the window itself to the workers (with idx=0) instead of the
        # whole image per task; the digest is identical but far less data is pickled.
        tasks = [{"idx": 0, "img": img[start:start + window_size], "window_size": window_size}
                 for start in range(0, num_windows * step_size, step_size)]
        digests = self.pool.map(hash_helper, tasks)
        # De-duplicate, reverse each hex string, and order descending so that the
        # leading entries form a deterministic subset of the hashes.
        hash_list = sorted((digest[::-1] for digest in set(digests)), reverse=True)
        return hash_list

    def check_img(self, hashes):
        """
        Check the matched hashes of current query with all previous queries.

        Args:
            hashes (list): the hash list of a query image.

        Returns:
            int, the largest number of hashes the current query shares with any single
            previous query (0 when nothing matches).
        """
        matches = Counter()
        for el in hashes:
            # Each known hash contributes one vote to every earlier query that produced it.
            matches.update(self.hash_dict.get(el, ()))
        if not matches:
            return 0
        return max(matches.values())

    def add_img(self, img):
        """
        Add the hashes of current query to the dict recording all the hashes, and compute the matched hashes with
        previous queries.

        Args:
            img (numpy.ndarray): the query image.

        Returns:
            int, the number of matched hashes of current query with all previous queries.
        """
        self.input_idx += 1
        hashes = self.hash_img(img, self.window_size, self.roundto, self.step_size)[:self.num_hashes_keep]
        # Count matches before registering, so a query never matches itself.
        cnt = self.check_img(hashes)
        for el in hashes:
            self.hash_dict.setdefault(el, []).append(self.input_idx)
        return cnt

    def fit(self, inputs, labels=None):
        """
        Process input training data to calculate the threshold.
        A proper threshold should make sure the false positive
        rate is under a given value.

        Args:
            inputs (Union[numpy.ndarray, list, tuple]): Data been used as
                references to create adversarial examples.
            labels: Unused.

        Raises:
            NotImplementedError: This function is not available
                in class `QueryDetector`.
        """
        msg = 'The function fit() is not available in the class ' \
              '`QueryDetector`.'
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def detect(self, queries):
        """
        Process queries to detect black-box attack.

        Args:
            queries (numpy.ndarray): Query sequence.

        Returns:
            list[int], ids (positions within `queries`) of adversarial examples detected.
        """
        queries = np.array(queries)
        adv_ids = []
        for query_id, query in enumerate(queries):
            query = self.gaussian_filter(query)
            match_num = self.add_img(query)
            is_attack = match_num > self.threshold
            LOGGER.info(TAG, 'Image: {}, max match: {}, attack_query: {}'.format(query_id, match_num, is_attack))
            if is_attack:
                adv_ids.append(query_id)

        num = len(adv_ids)
        LOGGER.info(TAG, "positive num on test dataset: {}".format(num))
        total = len(queries)
        # Guard against an empty query sequence, which would otherwise divide by zero.
        rate = num / float(total) if total else 0.0
        LOGGER.info(TAG, "positive rate: {}".format(rate))
        return adv_ids

    def gaussian_filter(self, img):
        """
        Filter the noise of query image before fed to the detector.

        Args:
            img (numpy.ndarray): the query image.

        Returns:
            numpy.ndarray, the query filtered.
        """
        # 5x5 Gaussian kernel; sigmaX=0 asks OpenCV to derive that sigma from the kernel size.
        img = cv2.GaussianBlur(src=img, ksize=(5, 5), sigmaX=0, sigmaY=1.0)
        return img

    def detect_diff(self, inputs):
        """
        Detect adversarial samples from input samples, like the predict_proba function in common machine learning model.

        Args:
            inputs (Union[numpy.ndarray, list, tuple]): Data been used as
                references to create adversarial examples.

        Raises:
            NotImplementedError: This function is not available
                in class `QueryDetector`.
        """
        msg = 'The function detect_diff() is not available in the class ' \
              '`QueryDetector`.'
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def transform(self, inputs):
        """
        Filter adversarial noises in input samples.

        Args:
            inputs (Union[numpy.ndarray, list, tuple]): Data been used as references to create adversarial examples.

        Raises:
            NotImplementedError: This function is not available in class `QueryDetector`.
        """
        msg = 'The function transform() is not available in the class ' \
              '`QueryDetector`.'
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def delete(self):
        """
        Close the pool for computing hashes and wait for its worker processes to exit.
        """
        self.pool.close()
        # join() after close() reaps the workers so no child processes are left behind.
        self.pool.join()

def hash_helper(arguments):
    """
    Hash one sliding-window segment of a flattened query image.

    Args:
        arguments (dict): a task description with the keys
            'img' (the buffer to slice), 'idx' (start offset of the window)
            and 'window_size' (number of elements in the window).

    Returns:
        str, the SHA-256 hex digest of ``img[idx:idx + window_size]``.
    """
    buffer = arguments['img']
    start = arguments['idx']
    stop = start + arguments['window_size']
    digest = hashlib.sha256(buffer[start:stop])
    return digest.hexdigest()
